Adding your own custom high-fidelity solver
Advanced users may want to replace the default solvers in ROSE with their own custom ones. This tutorial walks through that process for the scattering of 35 MeV (lab-frame) neutrons on $^{27}$Al, using the Koning-Delaroche optical potential for the neutron-nucleus interaction.
We will add the solver from the jitr package, which uses the calculable R-matrix method on a Lagrange-Legendre mesh. In principle, this method is very precise, and it can also handle non-local and coupled-channel potentials.
import rose
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path
import pickle
from tqdm import tqdm
from matplotlib import colormaps as cm
from matplotlib import pyplot as plt
from matplotlib import rcParams
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator, StrMethodFormatter
from rose.training import multiple_formatter
# Hex codes of matplotlib's default ten-color cycle, kept around for manual
# styling of figures.
colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

# Global figure style, applied in a single call: serif Computer Modern text
# rendered through LaTeX (usetex needs a working LaTeX install providing the
# amssymb, amsmath and braket packages), 12 pt text and tick labels, thicker
# lines, extra tick padding and a purple default colormap.
rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "text.latex.preamble": (
            r"\usepackage{amssymb} \usepackage{amsmath} \usepackage{braket}"
        ),
        "text.usetex": True,
        "legend.fontsize": 12,
        "font.size": 12,
        "font.weight": "normal",
        "xtick.labelsize": 12.0,
        "ytick.labelsize": 12.0,
        "lines.linewidth": 2.0,
        "xtick.major.pad": "10",
        "ytick.major.pad": "10",
        "image.cmap": "BuPu",
    }
)
# !pip install jitr
import jitr
from numba import njit
# set up kinematics
from rose.koning_delaroche import KDGlobal, Projectile

# target nucleus: 27-Al, given as (mass number A, charge Z)
A, Z = 27, 13

# lab bombarding energy [MeV]
E_lab = 35

# Kinematics for a (1, 0) projectile, i.e. a neutron, on the (A, Z) target.
# Presumably: reduced mass, center-of-mass energy, wavenumber and Sommerfeld
# parameter -- confirm against rose.utility.kinematics.
mu, E_com, k, eta = rose.utility.kinematics((A, Z), (1, 0), E_lab=E_lab)

# Global Koning-Delaroche optical potential for neutrons; get_params returns
# R_C (presumably the Coulomb radius) and the default parameter vector for this
# target and energy.
omp = KDGlobal(Projectile.neutron)
R_C, parameters = omp.get_params(A, Z, mu, E_com, k)
from rose.training import sample_params_LHC

# Train and test parameters are sampled from a box that extends `scale`
# (here 25%) above and below the default values. np.fabs keeps the lower
# bound below the upper bound even for negative default parameters.
scale = 0.25
half_width = np.fabs(scale * parameters)
# one (lower, upper) row per parameter -- assuming `parameters` is 1-D
bounds = np.array(
    [
        parameters - half_width,
        parameters + half_width,
    ]
).T

# Latin hypercube samples over +/-25% (`scale`) around the default
# parameters for a train/test split. The two sets use different seeds so
# each is reproducible.
n_train = 100
n_test = 100
training_samples = sample_params_LHC(n_train, parameters, scale, seed=13)
test_samples = sample_params_LHC(n_test, parameters, scale, seed=1233)
To use the jitr package as a custom solver in rose, we will need to use a derived class of rose.SchroedingerEquation: LagrangeRmatrix. This class is located in src/rose/lagrangelegendersolver.py, and can serve as a template example for interfacing external solvers in ROSE.
Let's set up the system:
# Work in the dimensionless radial coordinate s = k r. The channel matching
# radius s_0 is a whole number of 2*pi periods.
nodes_within_radius = 3
s_0 = nodes_within_radius * (2 * np.pi)  # channel matching radius
angles = np.linspace(1e-3, np.pi, 300)  # scattering angles [radians]
s_mesh = np.linspace(1e-10, s_0, 1000)  # starts just above the origin
domain = [s_mesh[0], s_0]

# Create an interaction space for partial waves up to l_max, with complex
# Koning-Delaroche central and spin-orbit terms. The charges are tied to the
# kinematics above so that changing the target cannot desynchronize them.
interactions = rose.InteractionEIMSpace(
    coordinate_space_potential=rose.koning_delaroche.KD_simple,
    n_theta=len(parameters),
    mu=mu,
    energy=E_com,
    is_complex=True,
    spin_orbit_term=rose.koning_delaroche.KD_simple_so,
    training_info=bounds,
    Z_1=0,  # uncharged projectile (neutron)
    Z_2=Z,  # target charge, same Z as passed to the kinematics
    R_C=R_C,
    l_max=20,
    expl_var_ratio_cutoff=1.0e-6,
    rho_mesh=s_mesh,
)
# The Runge-Kutta integration domain extends pi beyond the matching radius s_0.
rk_domain = [domain[0], domain[1] + np.pi]

# High-fidelity Runge-Kutta solver with [rtol, atol] = [1e-10, 1e-10]; it
# generates the training snapshots for the emulator below.
hifi_solver = rose.SchroedingerEquation.make_base_solver(
    s_0=s_0,
    domain=rk_domain,
    rk_tols=[1e-10, 1e-10],
)

# Train a scattering-amplitude emulator with 20 basis states on the training
# samples (scale / use_svd presumably toggle rescaling and SVD compression of
# the basis -- see rose.ScatteringAmplitudeEmulator.from_train).
emu = rose.ScatteringAmplitudeEmulator.from_train(
    interactions,
    training_samples,
    angles=angles,
    s_mesh=s_mesh,
    base_solver=hifi_solver,
    n_basis=20,
    use_svd=True,
    scale=True,
)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [04:38<00:00, 13.26s/it]
from matplotlib.lines import Line2D

# Compare the exact (scaled) potential with its EIM reconstruction for every
# other test sample, using the first interaction channel of the emulator.
# The potential is complex. Only its real part is drawn, and it is taken
# explicitly so matplotlib does not silently drop the imaginary part
# (ComplexWarning).
f, ax = plt.subplots(1, 1, figsize=(6, 4), facecolor="white", dpi=600)
interaction = emu.rbes[0][0].interaction
for sample in test_samples[::2]:
    (exact_line,) = ax.plot(
        s_mesh, np.real(interaction.tilde(s_mesh, sample)), alpha=0.5
    )
    ax.plot(
        s_mesh,
        np.real(interaction.tilde_emu(sample)),
        ":",
        alpha=0.5,
        color=exact_line.get_color(),  # same color as the matching exact curve
    )

# gray proxy handles so the legend describes line styles, not samples
legend_styles = [
    Line2D([0], [0], color="tab:gray", linestyle=":", alpha=0.8),
    Line2D([0], [0], color="tab:gray", alpha=0.8),
]
leg = ax.legend(
    legend_styles,
    ["EIM", "Exact"],
    loc="lower right",
)
# x ticks at multiples of pi/2
ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
ax.xaxis.set_major_formatter(plt.FuncFormatter(multiple_formatter(denominator=2)))
ax.set_ylabel(r"$\mathfrak{Re} \, U(s; \alpha)$")
ax.set_xlabel(r"$s = kr$")
ax.set_xlim([0, 8])
/home/kyle/mambaforge/envs/om/lib/python3.10/site-packages/matplotlib/cbook.py:1762: ComplexWarning: Casting complex values to real discards the imaginary part return math.isfinite(val) /home/kyle/mambaforge/envs/om/lib/python3.10/site-packages/matplotlib/cbook.py:1398: ComplexWarning: Casting complex values to real discards the imaginary part return np.asarray(x, float)
(0.0, 8.0)
# Spot-check one test sample: high-fidelity vs. emulated differential cross
# section over the scattering angles. Drawn on fresh axes so it never lands
# on the previous figure.
fig, ax = plt.subplots()
ax.plot(angles, emu.exact_dsdo(test_samples[4]), alpha=0.5, label="exact")
ax.plot(angles, emu.emulate_dsdo(test_samples[4]), ":", label="emu")
ax.set_yscale("log")
ax.legend()
<matplotlib.legend.Legend at 0x7416cabf6aa0>
# Lagrange-Legendre calculable R-matrix solver from jitr with 100 basis
# functions, wrapped by rose.LagrangeRmatrix. The wrapper receives the first
# interaction in the space (presumably as a template for all channels) and the
# matching radius s_0.
llrm_solver = rose.LagrangeRmatrix(
    interactions.interactions[0][0], s_0, jitr.rmatrix.Solver(100)
)
llrm = rose.ScatteringAmplitudeEmulator.HIFI_solver(
    interactions, base_solver=llrm_solver, angles=angles, s_mesh=s_mesh
)

# Reference ("ground truth") solver: the tightest Runge-Kutta tolerances used
# in this notebook, integrating pi past the matching radius s_0.
# NOTE: this rebinds `hifi_solver`; cells run afterwards see the 1e-11 solver.
rk_domain = [domain[0], domain[1] + np.pi]
hifi_solver = rose.SchroedingerEquation.make_base_solver(
    rk_tols=[1e-11, 1e-11], s_0=s_0, domain=rk_domain
)
ground_truth = rose.ScatteringAmplitudeEmulator.HIFI_solver(
    interactions, base_solver=hifi_solver, angles=angles, s_mesh=s_mesh
)
And that's it, we have our high-fidelity Lagrange-Legendre R-matrix solver and a Runge-Kutta solver. Let's compare them.
# Partial-wave solutions at the default KD parameters from both solvers
# (the R-matrix solver is evaluated first, then the Runge-Kutta reference).
solutions_llrm, solutions_rk = (
    solver.exact_wave_functions(parameters) for solver in (llrm, ground_truth)
)
First, let's compare partial waves without phase matching:
from rose.training import compare_partial_waves

# Real (left) and imaginary (right) parts of the first three partial waves
# exactly as each solver returns them: no normalization or phase alignment.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4), dpi=300)
fig.patch.set_facecolor("white")
compare_partial_waves(
    s_mesh,
    [solutions_rk[:3], solutions_llrm[:3]],
    ["Runge-Kutta", "Lagrange-Legendre"],
    fig, ax1, ax2,
)
(<Figure size 2700x1200 with 2 Axes>,
<Axes: xlabel='$s = kr$ [dimensionless]', ylabel='$\\mathfrak{Re} \\, u_{lj}(s)$ [a.u.]'>,
<Axes: xlabel='$s = kr$ [dimensionless]', ylabel='$\\mathfrak{Im} \\, u_{lj}(s)$ [a.u.]'>)
Now, let's match phases and look again:
# Align every Lagrange-Legendre partial wave with its Runge-Kutta counterpart.
# Each wave is multiplied (in place) by the ratio u_rk / u_llrm taken at one
# mesh point. This removes the overall complex normalization difference
# between the solvers so their shapes can be compared directly.
# TODO: agree on a common wavefunction normalization convention instead.
match_index = 20  # matching point s_mesh[match_index]
for waves_rk, waves_llrm in zip(solutions_rk, solutions_llrm):
    for j, (u_rk, u_llrm) in enumerate(zip(waves_rk, waves_llrm)):
        denominator = u_llrm[match_index]
        if denominator == 0:
            # Cannot match at a zero of the wave; leave it unscaled rather
            # than filling it with inf/nan.
            continue
        waves_llrm[j] *= u_rk[match_index] / denominator
# Same comparison as before (real part left, imaginary part right), now drawn
# from the phase-matched Lagrange-Legendre waves.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4), dpi=300)
fig.patch.set_facecolor("white")
compare_partial_waves(
    s_mesh,
    [solutions_rk[:3], solutions_llrm[:3]],
    ["Runge-Kutta", "Lagrange-Legendre"],
    fig, ax1, ax2,
)
(<Figure size 2700x1200 with 2 Axes>,
<Axes: xlabel='$s = kr$ [dimensionless]', ylabel='$\\mathfrak{Re} \\, u_{lj}(s)$ [a.u.]'>,
<Axes: xlabel='$s = kr$ [dimensionless]', ylabel='$\\mathfrak{Im} \\, u_{lj}(s)$ [a.u.]'>)
Now, let's take a look at phase shifts:
# Phase shifts at the default KD parameters from the reference and R-matrix
# solvers.
deltas_rk = ground_truth.exact_phase_shifts(parameters)
deltas_llrm = llrm.exact_phase_shifts(parameters)
from rose.training import plot_phase_shifts, compare_phase_shifts_err
from matplotlib import ticker

# 2x2 grid with a shared x axis. The taller top row holds the phase shifts of
# both solvers; the shorter bottom row presumably shows their discrepancy
# (see compare_phase_shifts_err).
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    2, 2, sharex=True, figsize=(9, 5), height_ratios=[1, 0.5], dpi=300
)
fig.patch.set_facecolor("white")
compare_phase_shifts_err(
    deltas_rk, deltas_llrm,
    "Runge-Kutta", "Lagrange-Legendre",
    fig, ax1, ax2, ax3, ax4,
    small_label1="RK",
    small_label2="N",
)
plt.tight_layout()
# Runge-Kutta solvers at increasingly tight [rtol, atol] tolerances.
rk_tolerances = [(1e-5, 1e-5), (1e-7, 1e-7), (1e-9, 1e-9)]
rk_solvers = [
    rose.ScatteringAmplitudeEmulator.HIFI_solver(
        interactions,
        base_solver=rose.SchroedingerEquation.make_base_solver(
            rk_tols=tolerance_pair, s_0=s_0, domain=rk_domain
        ),
        angles=angles,
        s_mesh=s_mesh,
    )
    for tolerance_pair in rk_tolerances
]

# Lagrange-Legendre R-matrix solvers with increasing numbers of basis
# functions.
nbases = [30, 50, 70]
llrm_solvers = [
    rose.ScatteringAmplitudeEmulator.HIFI_solver(
        interactions,
        base_solver=rose.LagrangeRmatrix(
            interactions.interactions[0][0], s_0, jitr.rmatrix.Solver(n_basis)
        ),
        angles=angles,
        s_mesh=s_mesh,
    )
    for n_basis in nbases
]
Not bad! How does this translate to cross sections? Let's calculate $d\sigma/d\Omega$ for various instances of our interaction with different parameters.
%%time
# calculate ground truth differential cross sections for each sample
fpath = Path("ground_truth_rm_vs_rk_test.pkl")
if fpath.is_file():
with open(fpath, 'rb') as f:
ground_truth_output = pickle.load(f)
else:
ground_truth_output = []
for sample in tqdm(test_samples):
ground_truth_output.append(ground_truth.exact_xs(sample).dsdo)
with open(fpath, 'wb') as f:
pickle.dump(ground_truth_output, f)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [06:19<00:00, 3.79s/it]
CPU times: user 6min 19s, sys: 3 s, total: 6min 22s Wall time: 6min 19s
%%time
# calculate lowfi differential cross sections for each sample to compare
lowfi_solver = llrm_solvers[1]
lowfi_method = "Lagrange-Legendre n={}".format(nbases[1])
lowfi_output = []
for sample in tqdm(test_samples):
lowfi_output.append(lowfi_solver.exact_xs(sample).dsdo)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 132.65it/s]
CPU times: user 756 ms, sys: 2 ms, total: 758 ms Wall time: 756 ms
from matplotlib.lines import Line2D
from rose.training import multiple_formatter

# Overlay ground-truth (solid) and low-fidelity (dashed) cross sections for
# the first few test samples; both curves of a sample share one color.
fig, ax = plt.subplots(figsize=(9, 4), dpi=300)
fig.patch.set_facecolor("white")
n2plot = 10  # only plot a few samples to keep the figure readable
for xs_hifi, xs_lowfi in zip(ground_truth_output[:n2plot], lowfi_output[:n2plot]):
    (hifi_line,) = ax.plot(angles, xs_hifi, alpha=0.5)
    ax.plot(angles, xs_lowfi, "--", color=hifi_line.get_color(), alpha=0.5)

# gray proxy handles so the legend explains line styles, not samples
legend_styles = [
    Line2D([0], [0], color="tab:gray", linestyle="--", alpha=0.8),
    Line2D([0], [0], color="tab:gray", alpha=0.8),
]
leg = ax.legend(
    legend_styles,
    [lowfi_method, "Runge-Kutta [1e-11, 1e-11]"],
    loc="lower left",
)
# x ticks at multiples of pi/6
ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 6))
ax.xaxis.set_major_formatter(plt.FuncFormatter(multiple_formatter(denominator=6)))
ax.set_yscale("log")
ax.set_ylabel(r"$ d \sigma / d\Omega $ [mb/Sr]")
ax.set_xlabel(r"$\theta$ [radians]")
Text(0.5, 0, '$\\theta$ [radians]')
It looks like $n=50$ is quite good, although a few samples show minor disagreements with the high-fidelity solver. Let's quantify this with a computational accuracy vs. time (CAT) plot:
from rose.training import CATPerformance
%%time
rk_solver_performances = []
for solver, tols in zip(rk_solvers, rk_tolerances):
label = r"$\left[{:1.0e} , {:1.0e}\right]$".format(*tols)
fpath = Path("perf_rk_{:1.0e}_{:1.0e}.pkl".format(*tols))
if fpath.is_file():
with open(fpath, 'rb') as f:
perf = pickle.load(f)
rk_solver_performances.append(perf)
else:
rk_solver_performances.append(
CATPerformance(
benchmark_runner=lambda sample: solver.exact_xs(sample).dsdo,
benchmark_inputs=test_samples,
benchmark_ground_truth=ground_truth_output,
label=label,
)
)
with open(fpath, 'wb') as f:
pickle.dump(rk_solver_performances[-1], f)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:26<00:00, 3.80it/s] 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:00<00:00, 1.66it/s] 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [02:30<00:00, 1.50s/it]
CPU times: user 3min 57s, sys: 3.83 s, total: 4min Wall time: 3min 56s
%%time
llrm_solver_performances = []
for solver, n in zip(llrm_solvers, nbases):
label = r"$N = {}$".format(n)
fpath = Path("perf_llrm_N_{}.pkl".format(n))
if fpath.is_file():
with open(fpath, "rb") as f:
perf = pickle.load(f)
llrm_solver_performances.append(perf)
else:
llrm_solver_performances.append(
CATPerformance(
benchmark_runner=lambda sample: solver.exact_xs(sample).dsdo,
benchmark_inputs=test_samples,
benchmark_ground_truth=ground_truth_output,
label=label,
)
)
with open(fpath, "wb") as f:
pickle.dump(llrm_solver_performances[-1], f)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 238.10it/s] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 139.88it/s] 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.90it/s]
CPU times: user 9.07 s, sys: 18 s, total: 27 s Wall time: 4.49 s
# Emulator configurations as [basis size, number of terms in EIM
# decomposition]; here both numbers grow together: 5, 10, 15, 20.
sae_configs = [(size, size) for size in (5, 10, 15, 20)]

# Build one emulator per configuration, all trained on the same samples, with
# whatever `hifi_solver` is currently bound generating the training snapshots.
_, emulators = rose.training.build_sae_config_set(
    sae_configs,
    interactions,
    training_samples,
    bounds,
    base_solver=hifi_solver,
    angles=angles,
    s_mesh=s_mesh,
    use_svd=True,
    scale=True,
)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [07:12<00:00, 20.59s/it] 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [02:24<00:00, 36.11s/it]
# Spot-check the last emulator built above on test sample 4. The plot shows
# its own high-fidelity solve ("exact"), its emulation ("emu") and the cached
# ground-truth cross section ("gto").
emu = emulators[-1]
ax = plt.gca()
ax.plot(angles, emu.exact_dsdo(test_samples[4]), label="exact", alpha=0.5)
ax.plot(angles, emu.emulate_dsdo(test_samples[4]), label="emu")
ax.plot(angles, ground_truth_output[4], ":", label="gto")
ax.set_yscale("log")
ax.legend()
<matplotlib.legend.Legend at 0x7416ca7f7610>
# Benchmark each reduced-basis emulator against the ground truth for the CAT
# plot. Each CATPerformance is cached in its own pickle, keyed by the
# (basis size, EIM terms) configuration.
# NOTE(review): the cache key ignores the physical system, the test samples
# and the ground truth; delete the perf_rbm_*.pkl files after changing any of
# them, or stale results will be reused.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
rbm_performances = []
for solver, config in zip(emulators, sae_configs):
    label = r"$\left({:d} , {:d}\right)$".format(*config)
    fpath = Path("perf_rbm_{:d}_{:d}.pkl".format(*config))
    if fpath.is_file():
        with open(fpath, "rb") as f:
            perf = pickle.load(f)
        rbm_performances.append(perf)
    else:
        rbm_performances.append(
            CATPerformance(
                # Bind this iteration's emulator as a default argument. A plain
                # closure looks `solver` up when it is *called*, so any call
                # made after the loop advances would use the wrong emulator.
                benchmark_runner=(
                    lambda sample, solver=solver: solver.emulate_xs(sample).dsdo
                ),
                benchmark_inputs=test_samples,
                benchmark_ground_truth=ground_truth_output,
                label=label,
            )
        )
        with open(fpath, "wb") as f:
            pickle.dump(rbm_performances[-1], f)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 395.88it/s] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 372.39it/s] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 331.67it/s] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 248.11it/s]
# Computational Accuracy vs. Time (CAT) plot of the Runge-Kutta solvers, the
# Lagrange-Legendre R-matrix solvers and the two largest emulator
# configurations (rbm_performances[2:]), saved to cat_rbm_rm_rk.pdf.
from rose.training import CAT_plot

fig, ax = CAT_plot(
    [
        rk_solver_performances,
        llrm_solver_performances,
        rbm_performances[2:],
    ],
    labels=[
        "Runge-Kutta [rtol, atol]",
        "Lagrange-Legendre",
        "RBM [$n_{rbm}$, $n_{eim}$]",
    ],
    border_styles=[":", "--", "-"],
)
# "\%" is escaped because text.usetex is on ("%" starts a LaTeX comment)
ax.set_ylabel(r"median relative error in $d \sigma / d \Omega$ [\%]", size=16)
ax.set_xlabel(r"time per sample [s]", size=16)
ax.set_ylim(1e-10, 10)
ax.set_xlim(1e-3, 3)
plt.tight_layout()
plt.savefig("cat_rbm_rm_rk.pdf")
Great! It looks like this Lagrange-Legendre method could save us a lot of time in training. For most purposes, $n \sim 100$ should be good, with less than 0.1% median relative error across all samples and a speed-up of roughly 1–1.5 orders of magnitude over the Runge-Kutta solvers.
# Computational Accuracy vs. Time (CAT) plot of the Runge-Kutta and
# Lagrange-Legendre R-matrix solvers only.
from rose.training import CAT_plot

fig, ax = CAT_plot(
    [rk_solver_performances, llrm_solver_performances],
    labels=["Runge-Kutta [rtol, atol]", "Lagrange-Legendre"],
    border_styles=[":", "-"],
)
# This notebook renders text with LaTeX (plt.rc("text", usetex=True)). A bare
# "%" would start a LaTeX comment and break the label, so escape it as "\%",
# as the other CAT plot already does.
ax.set_ylabel(r"median relative error in $d \sigma / d \Omega$ [\%]")
ax.set_ylim(1e-10, 100)
ax.set_xlim(2e-3, 10)
(0.002, 10)